import argparse
import copy
import os
import random
import time
import pdb
import numpy as np
import torch
from mmcv import Config
from torchpack import distributed as dist
from torchpack.environ import auto_set_run_dir, set_run_dir
from torchpack.utils.config import Config as CFG
from mmcv.runner import load_checkpoint

from mmdet3d.apis import train_model,distil_train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from mmdet3d.utils import get_root_logger, convert_sync_batchnorm, recursive_eval


def main():
    """Entry point: train a student model via knowledge distillation from a
    pretrained teacher model.

    Expects to be launched under torchpack distributed (one process per GPU).
    Required CLI flags: --config_s (student config), --config_t (teacher
    config), --checkpoint (teacher weights). Optional: --run-dir.
    """
    dist.init()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_s", metavar="FILE", required=True, help="student config file"
    )
    parser.add_argument(
        "--config_t", metavar="FILE", required=True, help="teacher config file"
    )
    parser.add_argument("--run-dir", metavar="DIR", help="run directory")
    parser.add_argument(
        "--checkpoint", metavar="FILE", required=True, help="teacher checkpoint"
    )
    # parse_known_args tolerates extra override options on the command line.
    args, opts = parser.parse_known_args()

    # Student config: load with torchpack (recursive include support), then
    # convert into an mmcv Config for the mmdet3d training APIs.
    config_s = CFG()
    config_s.load(args.config_s, recursive=True)
    cfg = Config(recursive_eval(config_s), filename=args.config_s)

    # Teacher config: same two-stage loading.
    config_t = CFG()
    config_t.load(args.config_t, recursive=True)
    cfg_t = Config(recursive_eval(config_t), filename=args.config_t)

    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    torch.cuda.set_device(dist.local_rank())

    if args.run_dir is None:
        args.run_dir = auto_set_run_dir()
    else:
        set_run_dir(args.run_dir)
    cfg.run_dir = args.run_dir

    # Dump the student config so the run is reproducible.
    cfg.dump(os.path.join(cfg.run_dir, "configs.yaml"))

    # Init the logger before other steps.
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    log_file = os.path.join(cfg.run_dir, f"{timestamp}.log")
    logger = get_root_logger(log_file=log_file)

    # Set random seeds for reproducibility.
    if cfg.seed is not None:
        logger.info(
            f"Set random seed to {cfg.seed}, "
            f"deterministic mode: {cfg.deterministic}"
        )
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        # torch.manual_seed alone does not seed the CUDA RNG of every
        # device; cover all GPUs explicitly.
        torch.cuda.manual_seed_all(cfg.seed)
        if cfg.deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    datasets = build_dataset(cfg.data.train)

    # Student model: trained from scratch (init_weights below).
    model = build_model(cfg.model)

    # Teacher model: built from its own config and loaded from the
    # pretrained checkpoint; weights stay on CPU until the trainer moves
    # them.
    tmodel = build_model(cfg_t.model)
    load_checkpoint(tmodel, args.checkpoint, map_location="cpu")

    model.init_weights()
    if cfg.get("sync_bn", None):
        if not isinstance(cfg["sync_bn"], dict):
            cfg["sync_bn"] = dict(exclude=[])
        model = convert_sync_batchnorm(model, exclude=cfg["sync_bn"]["exclude"])

    logger.info(f"Model:\n{model}")
    distil_train_model(
        model,
        tmodel,
        datasets,
        cfg,
        distributed=True,
        validate=True,
        timestamp=timestamp,
    )
        


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
